# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from utils import data_utils
from utils.data_utils import item
from torch import Tensor


class SRNLossCriterion(nn.Module):
    """Weighted combination of rendering losses for a network output.

    Individual terms (color, alpha, depth, VGG, eikonal, regz) are switched on
    through the corresponding ``*_weight`` fields of ``args``.

    NOTE(review): ``compute_loss`` reads ``self.vgg`` (when ``vgg_weight > 0``)
    and ``self.task._num_updates`` (when ``depth_weight_decay`` is set).
    Neither attribute is assigned in this class, so they are presumably
    attached by a subclass or by the caller -- confirm.
    """

    def __init__(self, args):
        """Store the configuration namespace.

        Args:
            args: namespace providing ``no_background_loss``, ``L1``,
                ``color_weight``, ``alpha_weight``, ``depth_weight``,
                ``depth_weight_decay``, ``max_depth``, ``vgg_weight``,
                ``eikonal_weight`` and ``regz_weight``.
        """
        # nn.Module has to be initialised first: otherwise sub-modules
        # (e.g. ``self.vgg``) cannot be assigned and the module is unusable.
        super().__init__()
        self.args = args

    def compute_loss(self, net_output, sample, reduce=True, cal_psnr=True):
        """Compute the total weighted loss and the values to log.

        Args:
            net_output: dict read for ``'colors'``, ``'missed'``, ``'depths'``,
                ``'eikonal-term'``, ``'regz-term'`` and (optionally)
                ``'other_logs'``, depending on which weights are enabled.
                ``'regz-term'`` is always required.
            sample: dict read for ``'sampled_uv'``, ``'size'``, ``'colors'``
                and, depending on the configuration, ``'alpha'``,
                ``'depths'`` and ``'mask'``.
            reduce: accepted for interface compatibility; not used here.
            cal_psnr: when true and the color loss is enabled, log ``'PSNR'``.

        Returns:
            ``(loss, logging_outputs)`` where ``loss`` is the sum of every
            enabled term multiplied by its weight and ``logging_outputs`` maps
            each term name to its unweighted value, merged with extra logs.

        Raises:
            AssertionError: if ground-truth colors are missing, or if the VGG
                loss is requested without patch-based sampling.
            ValueError: if the background depth loss is requested but neither
                ``sample['mask']`` nor an alpha-derived mask is available.
        """
        losses, other_logs = {}, {}

        # ---- prepare data before computing loss ----
        sampled_uv = sample['sampled_uv']  # S, V, 2, N, P, P (patch-size)
        if len(sampled_uv.size()) == 4:
            # No patch dimensions supplied: treat every sample as a 1x1 patch.
            sampled_uv = sampled_uv.unsqueeze(-1).unsqueeze(-1)
        S, V, _, N, P1, P2 = sampled_uv.size()
        _, W, h, w = sample['size'].tolist()
        L = N * P1 * P2  # number of sampled pixels per view
        flatten_uv = sampled_uv.view(S, V, 2, L)
        # Convert (u, v) coordinates into indices of the flattened image.
        flatten_index = (
            flatten_uv[:, :, 0] // h + flatten_uv[:, :, 1] // w * W
        ).long()

        assert 'colors' in sample and sample['colors'] is not None, "ground-truth colors not provided"
        # TODO: remove the two leading singleton dimensions.
        # presumably (1, 1, num_pixels, 3): dim 2 is gathered below -- confirm
        target_colors = sample['colors'].unsqueeze(0).unsqueeze(0)
        masks = (sample['alpha'] > 0) if self.args.no_background_loss else None
        if L < target_colors.size(2):
            # Only a subset of pixels was rendered: pick the matching targets.
            target_colors = target_colors.gather(
                2, flatten_index.unsqueeze(-1).repeat(1, 1, 1, 3))
            if masks is not None:
                # Index with the integer pixel indices, not the raw uv
                # coordinates (gather needs an int64 index of matching rank).
                # NOTE(review): assumes masks is (S, V, num_pixels) -- confirm.
                masks = masks.gather(2, flatten_index)

        if 'other_logs' in net_output:
            other_logs.update(net_output['other_logs'])

        # ---- computing loss ----
        if self.args.color_weight > 0:
            color_loss = utils.rgb_loss(
                net_output['colors'], target_colors,
                masks, self.args.L1)
            losses['color_loss'] = (color_loss, self.args.color_weight)
            if cal_psnr:
                psnr = data_utils.mse2psnr(color_loss)
                other_logs['PSNR'] = psnr.item()

        if self.args.alpha_weight > 0:
            _alpha = net_output['missed'].reshape(-1)
            alpha = _alpha.float()
            # The penalty is zero at alpha in {0, 1} and peaks at 0.5.
            alpha_loss = torch.log1p(
                1. / 0.11 * alpha * (1 - alpha)
            ).mean().type_as(_alpha)
            losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight)

        if self.args.depth_weight > 0:
            if sample['depths'] is not None:
                # TODO: remove the two leading singleton dimensions.
                target_depths = sample['depths'].unsqueeze(0).unsqueeze(0)
                target_depths = target_depths.gather(2, flatten_index)
                depth_mask = target_depths > 0
                depth_loss = utils.depth_loss(
                    net_output['depths'], target_depths, depth_mask)
            else:
                # No depth map is provided: the depth loss is only applied on
                # the background, pushing it towards ``max_depth``.
                max_depth_target = self.args.max_depth * torch.ones_like(
                    net_output['depths'])
                if sample['mask'] is not None:
                    background = (1 - sample['mask']).bool()
                elif masks is not None:
                    background = ~masks
                else:
                    raise ValueError(
                        "depth loss without depth maps requires sample['mask'] "
                        "or no_background_loss to derive a background mask")
                depth_loss = utils.depth_loss(
                    net_output['depths'], max_depth_target, background)

            depth_weight = self.args.depth_weight
            if self.args.depth_weight_decay is not None:
                # SECURITY: eval() executes arbitrary code; the option must
                # only ever come from a trusted configuration.
                final_factor, final_steps = eval(self.args.depth_weight_decay)
                # Linearly decay the weight with the number of updates,
                # clamped so the multiplier never becomes negative.
                depth_weight *= max(
                    0,
                    1 - (1 - final_factor) * self.task._num_updates / final_steps)
                other_logs['depth_weight'] = depth_weight

            losses['depth_loss'] = (depth_loss, depth_weight)

        if self.args.vgg_weight > 0:
            assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss"
            # Reshape to image patches (B, 3, P1, P2) and apply ``* .5 + .5``.
            vgg_target = target_colors.reshape(
                -1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5
            vgg_output = net_output['colors'].reshape(
                -1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5
            vgg_loss = self.vgg(vgg_output, vgg_target)
            losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight)

        if self.args.eikonal_weight > 0:
            losses['eik_loss'] = (
                net_output['eikonal-term'].mean(), self.args.eikonal_weight)

        # NOTE(review): this term is not guarded by ``regz_weight > 0``, so
        # ``net_output['regz-term']`` must always be present.
        losses['reg_loss'] = (
            net_output['regz-term'].mean(), self.args.regz_weight)

        loss = sum(value * weight for value, weight in losses.values())

        logging_outputs = {key: item(value) for key, (value, _) in losses.items()}
        logging_outputs.update(other_logs)
        return loss, logging_outputs
